# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""The rope embedding used within the model."""

from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor, defaults, defaults_like
from max.graph import Dim

from ..module import Module, module_dataclass


def theta(dim: int, base: float) -> Tensor:
    """Computes the inverse-exponential frequencies used to build rotary
    positional embeddings.

    See 'Roformer: Enhanced Transformer with Rotary Embedding'
    (arxiv.org/pdf/2104.09864).

    In the paper, the frequencies are defined as
    ``theta_i = 10000 ** (-2 * (i - 1) / d)`` for ``i`` in ``1..d/2``.
    Here ``base`` takes the place of ``10000`` and ``dim`` the place of ``d``.

    Args:
        dim: The embedding dimension. By convention each component
            of the complex valued embedding is considered its own dim
            in the embedding, so output has shape `dim // 2`.
        base: The base of the exponential. Entry ``k`` of the result is
            ``base ** (-2 * k / dim)``.

    Returns:
        The 1d frequency tensor with shape (dim // 2), cast to the dtype
        reported by `defaults()`.
    """
    target_dtype, _ = defaults()
    # Evaluate the exponential in float64 for extra range, then cast down
    # to the requested dtype only at the very end.
    even_indices = Tensor.arange(dim, step=2, dtype=DType.float64)
    exponents = -even_indices / dim
    inverse_frequencies = base**exponents
    return inverse_frequencies.cast(target_dtype)


def embed(
    frequencies: Tensor,
    max_sequence_length: int,
) -> Tensor:
    """Builds the table of complex exponentials, in cis representation
    (cos(s) + i * sin(s)), for every position up to a sequence length.

    Each position ``p`` in ``[0, max_sequence_length)`` is multiplied with
    each frequency ``f`` to give an angle ``p * f``. The cosine and sine of
    that angle are stored as the trailing pair of the result.

    Args:
        frequencies: Frequencies to embed in the cyclic space, as a 1d
            tensor of length n / 2.
        max_sequence_length: The number of positional embeddings to compute

    Returns:
        The embedded frequency tensor with shape
        (max_sequence_length, n / 2, 2), where the last axis holds
        (cos, sin).
    """
    with defaults_like(frequencies):
        # Positions are generated in float64 before the outer product.
        positions = Tensor.arange(max_sequence_length, dtype=DType.float64)
        # [max_sequence_length, n // 2]
        angles = F.outer(positions, frequencies)
        angles = angles.cast(frequencies.dtype)
        real_part = F.cos(angles)
        imaginary_part = F.sin(angles)
        # [max_sequence_length, n // 2, 2]
        return F.stack([real_part, imaginary_part], axis=-1)


def positional_embedding(
    dim: int, base: float, max_sequence_length: int
) -> Tensor:
    """Builds the rotary positional embedding table for all positions up to
    a specified sequence length.

    This chains the two helpers in this module: `theta` produces the
    per-channel frequencies and `embed` expands them over positions.

    See 'Roformer: Enhanced Transformer with Rotary Embedding'
    (arxiv.org/pdf/2104.09864).

    Args:
        dim: The embedding dimension. By convention each component
            of the complex valued embedding is considered its own dim.
        base: Base of the exponential used to derive the frequencies.
        max_sequence_length: The number of positional embeddings to compute.

    Returns:
        RoPE positional embeddings of shape (max_sequence_length, dim / 2, 2).
    """
    frequencies = theta(dim, base)
    return embed(frequencies, max_sequence_length)


@module_dataclass
class RotaryEmbedding(Module):
    """Applies rotary positional embeddings (RoPE) from a pre-computed table.

    The table is stored in `weight`. Calling the module rotates an
    activation tensor whose last axis is laid out as interleaved
    (real, imaginary) pairs.
    """

    #: Pre-computed embeddings of shape [max_sequence_length, n // 2, 2]
    weight: Tensor

    @property
    def dim(self) -> int:
        """The embedding dimension: twice the size of axis 1 of `weight`."""
        half_dim = int(self.weight.shape[1])
        return 2 * half_dim

    @property
    def max_sequence_length(self) -> int:
        """The number of positions in `weight` (the size of its axis 0)."""
        table_length = self.weight.shape[0]
        return int(table_length)

    def __rich_repr__(self):
        """Yields the (name, value) pairs displayed in the rich repr."""
        yield "dim", self.dim
        yield "max_sequence_length", self.max_sequence_length

    @F.functional
    def __call__(self, x: Tensor, start_pos: Dim = Dim(0)) -> Tensor:
        """Applies rotary positional embeddings (RoPE) to `x`.

        seq_len is inferred from the shape of `x`.

        Args:
            x: Activation tensor with shape (batch, seq_len, n_kv_heads, head_dim).
                x is interpreted as a complex number valued tensor where the
                `head_dim` dimension is alternating pairs of (real, imaginary)
                parts.
            start_pos: Position of the first element of `x` along the
                sequence axis. Rows `start_pos` to `start_pos + seq_len` of
                `weight` are used. Defaults to 0.

        Returns:
            Input activation tensor with rotary positional embeddings applied and
            the same shape as `x`.
        """
        # Unpacking four names also requires `x` to be rank 4.
        _batch, seq_len, _n_heads, _head_dim = x.shape

        interleaved = F.as_interleaved_complex(x)
        end_pos = start_pos + seq_len
        # The inserted `None` axis lines the table up against the heads axis
        # of `x`, so the same rotation is used for every head.
        rotations = self.weight[start_pos:end_pos, None, ...]
        rotated = F.complex_mul(interleaved, rotations)
        return rotated.reshape(x.shape)


class TransposedRotaryEmbedding(RotaryEmbedding):
    """Variant of `RotaryEmbedding` for inputs whose last axis stores all
    real parts first, followed by all imaginary parts.
    """

    @F.functional
    def __call__(self, x: Tensor, start_pos: Dim = Dim(0)) -> Tensor:
        """Applies rotary positional embeddings (RoPE) to `x`.

        The representation of `x` is transposed within the final dimension
        compared to traditional `RotaryEmbedding`.

        seq_len is inferred from the shape of `x`.

        Args:
            x: Activation tensor with shape (batch, seq_len, n_kv_heads, head_dim).
                x is interpreted as a complex number valued tensor where the
                first half of `head_dim` are the real parts and the last half
                are the imaginary parts.
            start_pos: Position of the first element of `x` along the
                sequence axis. Rows `start_pos` to `start_pos + seq_len` of
                `weight` are used. Defaults to 0.

        Returns:
            Input activation tensor with rotary positional embeddings applied and
            the same shape as `x`.
        """
        # Unpacking four names also requires `x` to be rank 4.
        batch, seq_len, n_heads, head_dim = x.shape

        # Split the last axis into its (real | imaginary) halves, then
        # transpose so each (real, imaginary) pair is adjacent.
        # NOTE(review): relies on `.T` swapping the last two axes — confirm
        # against the Tensor API.
        halves = x.reshape((batch, seq_len, n_heads, 2, head_dim // 2))
        paired = halves.T

        end_pos = start_pos + seq_len
        # The inserted `None` axis lines the table up against the heads axis.
        rotations = self.weight[start_pos:end_pos, None, ...]

        rotated = F.complex_mul(paired, rotations)
        # Undo the transpose to restore the half-split layout, then flatten
        # back to the caller's shape.
        return rotated.T.reshape(x.shape)
